feat: ring-based all_gather with workspace preamble #490
Closed
Conversation
Ring-based all-gather where each rank forwards data around a ring instead of performing all-pairs writes. In each step, rank r reads a shard from its output buffer and iris.store()s it to the next rank; after N-1 steps every rank has all shards. Uses per-tile flag-based producer/consumer sync. Benefits: O(1) peer writes per step (vs. O(N) fan-out) and reduced memory-controller contention for large messages.

- New kernel: persistent_all_gather_ring
- Config: all_gather_variant="ring"
- Tests: test_all_gather_ring with multiple shapes and dtypes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
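The schedule above can be sketched as a pure-Python simulation. `ring_all_gather` below is illustrative, not the Iris kernel; it models each rank's output buffer as a list and the `iris.store()` as a plain assignment:

```python
# Illustrative simulation of the ring all-gather schedule (not the Iris API).
def ring_all_gather(shards):
    """Each rank starts with its own shard; after N-1 steps every rank
    holds all shards, having forwarded data around the ring."""
    world_size = len(shards)
    # output[r] models rank r's output buffer, seeded with its own shard
    output = [[None] * world_size for _ in range(world_size)]
    for r in range(world_size):
        output[r][r] = shards[r]
    for step in range(world_size - 1):
        for r in range(world_size):
            # In step k, rank r forwards the shard that originated at
            # rank (r - k) % world_size to its ring successor.
            src = (r - step) % world_size      # shard rank r currently holds
            nxt = (r + 1) % world_size         # ring successor
            output[nxt][src] = output[r][src]  # models the iris.store()
    return output

result = ring_all_gather(["a", "b", "c", "d"])
# every rank ends with all shards, in origin order
assert all(row == ["a", "b", "c", "d"] for row in result)
```

Within a step no rank reads a slot written in the same step, which is why the per-step producer/consumer handshake is enough to order the forwarding.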
Mirrors the proven all_reduce ring kernel: uses a separate ring_buffer on the symmetric heap for receiving data, simple 0/1 flag toggling, and no step-counting. Each step: send current data to the next rank's ring_buffer, wait for the predecessor to write into our ring_buffer, copy the received data to the correct output slot, then forward it in the next step. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Triton's % operator uses C truncated-division semantics where (-1) % 4 = -1, not Python's floored-division where (-1) % 4 = 3. This caused recv_rank_idx and source_rank_idx to be negative for ranks where group_rank < step, writing to invalid output locations. Fix: add world_size before the modulo to ensure non-negative operands. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
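The fix can be checked in plain Python by emulating C's truncated-division remainder (the `c_mod` helper below is hypothetical, written only to reproduce Triton's `%` semantics):

```python
# Emulate C/Triton truncated-division remainder: the sign follows the
# dividend, unlike Python's floored %.
def c_mod(a, n):
    return a - int(a / n) * n

world_size = 4
group_rank, step = 0, 1

# Buggy form: negative under C semantics whenever group_rank < step,
# which indexed outside the valid output locations.
buggy = c_mod(group_rank - step, world_size)
assert buggy == -1

# Fixed form: add world_size first so the operand is non-negative.
fixed = c_mod(group_rank - step + world_size, world_size)
assert fixed == 3  # matches Python's floored (0 - 1) % 4
```

Adding `world_size` is safe because the dividend is always greater than `-world_size`, so one extra multiple of the modulus brings it into the non-negative range without changing the result class.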
Rename shmem → ctx to align with codebase convention (README, benchmarks, and docs all use ctx). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pre-allocate ring_buffer and flags on the symmetric heap once, and pass the workspace to all_gather to avoid per-call allocation overhead (~13 ms saved per call in benchmarks). Follows the same pattern as AllReduceWorkspace / all_reduce_preamble. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
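A minimal sketch of the preamble pattern, assuming the AllGatherWorkspace / all_gather_preamble names from the PR; the bytearray and flag list here are stand-ins for Iris's symmetric-heap allocations, not its real API:

```python
from dataclasses import dataclass

@dataclass
class AllGatherWorkspace:
    ring_buffer: bytearray  # staging buffer for the incoming shard (stand-in)
    flags: list             # one producer/consumer flag per chunk (stand-in)

def all_gather_preamble(shard_nbytes, num_chunks):
    # Allocate scratch buffers once, up front, instead of on every
    # all_gather call (the PR reports ~13 ms saved per call).
    return AllGatherWorkspace(
        ring_buffer=bytearray(shard_nbytes),
        flags=[0] * num_chunks,
    )

ws = all_gather_preamble(shard_nbytes=1024, num_chunks=8)
# Subsequent calls reuse ws rather than re-allocating, e.g.:
# all_gather(output, shard, workspace=ws)
assert len(ws.ring_buffer) == 1024 and ws.flags == [0] * 8
```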
Wraps remote operations (iris.atomic_cas, iris.store, iris.atomic_xchg) with DeviceTracing record_event_start/end for per-operation timing. Zero overhead when tracing=False via constexpr dead-code elimination. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The tracing API's record_event_start requires block pointers (for tl.min reduction), not scalar pointers. Restructure to trace full ring steps using the ring_buffer tile address instead of individual atomic ops. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… flags Replace O(tiles * steps) remote atomic operations with O(steps) by using a global barrier. All CUs write tiles in parallel for each step, then synchronize via a local counter + single remote signal. This eliminates the massive atomic contention that caused ~10x slowdown vs RCCL. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Previous global barrier had two bugs:
1. Race between step reset and CU entry (fixed: monotonic counter)
2. Write-after-read hazard on single ring_buffer (fixed: double-buffer)

Step k reads from buf[k%2], writes to next rank's buf[(k+1)%2]. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
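The double-buffer indexing can be checked directly: at step k a rank reads the data it received in buf[k % 2] while the predecessor's send for that step lands in buf[(k + 1) % 2], so an incoming write can never clobber unread data. A small illustrative check:

```python
# Illustrative check of the double-buffer slot schedule (not kernel code).
world_size = 4
schedule = []
for step in range(world_size - 1):
    read_slot = step % 2          # buffer holding last step's received data
    write_slot = (step + 1) % 2   # buffer the predecessor writes into now
    # The slot being read is never the slot being written this step.
    assert read_slot != write_slot
    schedule.append((read_slot, write_slot))

# Slots strictly alternate across steps: (0,1), (1,0), (0,1), ...
assert schedule == [(0, 1), (1, 0), (0, 1)]
```

The alternation also means step k+1 reads exactly the slot step k wrote, so no data is orphaned between steps.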
Replace per-tile flag handshake with per-chunk handshake to amortize atomic synchronization overhead. The shard is split into NUM_CHUNKS row-bands, each with one flag. CUs process all tiles within a chunk before signaling, reducing atomics from 4*total_tiles*steps to 4*num_chunks*steps. Different CUs work on different chunks at different ring steps, creating a pipeline that keeps XGMI links busy. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
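The amortization is easy to quantify; the tile and chunk counts below are illustrative, not taken from the PR's benchmarks:

```python
# Illustrative atomic-operation counts for per-tile vs per-chunk handshakes.
def atomics_per_tile(total_tiles, steps):
    return 4 * total_tiles * steps   # old: 4 atomics per tile per step

def atomics_per_chunk(num_chunks, steps):
    return 4 * num_chunks * steps    # new: 4 atomics per chunk per step

# Example: 4096 tiles split into 16 row-band chunks, 7 ring steps
# (an 8-rank ring does world_size - 1 = 7 steps).
total_tiles, num_chunks, steps = 4096, 16, 7
old = atomics_per_tile(total_tiles, steps)     # 114688 remote atomics
new = atomics_per_chunk(num_chunks, steps)     # 448 remote atomics
assert old // new == total_tiles // num_chunks  # 256x fewer atomics
```

The reduction factor is simply total_tiles / num_chunks, so chunk size trades synchronization overhead against pipeline granularity.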
Defer local flag reset until AFTER reading ring_buffer for forwarding in the next step's send phase. Previously, resetting the flag at the end of the receive phase allowed the predecessor to overwrite ring_buffer before we read it for forwarding, causing data corruption. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
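The ordering constraint can be sketched with Python threads standing in for the two ranks; the kernel uses flags and remote atomics, not threads, so this is only an illustration of why the reset must follow the forwarding read:

```python
import threading

# Illustrative model: flag set = data available; the predecessor may
# only overwrite ring_buffer once the flag has been cleared.
ring_buffer = ["step0-data"]
flag = threading.Event()
flag.set()  # predecessor has produced step-0 data

forwarded = []

def consumer():
    flag.wait()                       # receive phase: data is ready
    copied = ring_buffer[0]           # copy to the correct output slot
    # ... next step's send phase ...
    forwarded.append(ring_buffer[0])  # read for forwarding FIRST
    flag.clear()                      # THEN release the predecessor

def producer_next_step():
    while flag.is_set():              # blocked until consumer resets
        pass
    ring_buffer[0] = "step1-data"     # safe: consumer already read it

t = threading.Thread(target=producer_next_step)
t.start()
consumer()
t.join()
assert forwarded == ["step0-data"]   # no corruption: read preceded reset
```

Clearing the flag at the end of the receive phase instead (before the forwarding read) would let the producer's overwrite race the read, which is exactly the corruption the commit fixes.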
Summary
- Ring all_gather variant using flag-based producer/consumer sync
- AllGatherWorkspace + all_gather_preamble() for pre-allocated scratch buffers
- Modulo fix (% uses C truncated-division, not Python floored)
- shmem → ctx in all_gather module and ring tests

Test plan
test_all_gather_ring covers 4 shapes × 3 dtypes = 12 test cases

🤖 Generated with Claude Code